# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570

from jax._src.lax.lax import (
  DotDimensionNumbers as DotDimensionNumbers,
  RaggedDotDimensionNumbers as RaggedDotDimensionNumbers,
  AccuracyMode as AccuracyMode,
  Tolerance as Tolerance,
  Precision as Precision,
  PrecisionLike as PrecisionLike,
  DotAlgorithm as DotAlgorithm,
  DotAlgorithmPreset as DotAlgorithmPreset,
  RandomAlgorithm as RandomAlgorithm,
  RoundingMethod as RoundingMethod,
  abs as abs,
  abs_p as abs_p,
  acos as acos,
  acos_p as acos_p,
  acosh as acosh,
  acosh_p as acosh_p,
  add as add,
  add_p as add_p,
  after_all as after_all,
  after_all_p as after_all_p,
  and_p as and_p,
  argmax as argmax,
  argmax_p as argmax_p,
  argmin as argmin,
  argmin_p as argmin_p,
  asin as asin,
  asin_p as asin_p,
  asinh as asinh,
  asinh_p as asinh_p,
  atan as atan,
  atan_p as atan_p,
  atan2 as atan2,
  atan2_p as atan2_p,
  atanh as atanh,
  atanh_p as atanh_p,
  batch_matmul as batch_matmul,
  bitcast_convert_type as bitcast_convert_type,
  bitcast_convert_type_p as bitcast_convert_type_p,
  bitwise_and as bitwise_and,
  bitwise_not as bitwise_not,
  bitwise_or as bitwise_or,
  bitwise_xor as bitwise_xor,
  broadcast as broadcast,
  broadcast_in_dim as broadcast_in_dim,
  broadcast_in_dim_p as broadcast_in_dim_p,
  broadcast_shapes as broadcast_shapes,
  broadcast_to_rank as broadcast_to_rank,
  broadcasted_iota as broadcasted_iota,
  cbrt as cbrt,
  cbrt_p as cbrt_p,
  ceil as ceil,
  ceil_p as ceil_p,
  clamp as clamp,
  clamp_p as clamp_p,
  clz as clz,
  clz_p as clz_p,
  collapse as collapse,
  complex as complex,
  complex_p as complex_p,
  composite as composite,
  concatenate as concatenate,
  concatenate_p as concatenate_p,
  conj as conj,
  conj_p as conj_p,
  convert_element_type as convert_element_type,
  convert_element_type_p as convert_element_type_p,
  copy_p as copy_p,
  cos as cos,
  dce_sink_p as dce_sink_p,
  dce_sink as dce_sink,
  cos_p as cos_p,
  cosh as cosh,
  cosh_p as cosh_p,
  create_token as create_token,
  create_token_p as create_token_p,
  div as div,
  div_p as div_p,
  dot as dot,
  dot_general as dot_general,
  dot_general_p as dot_general_p,
  dtype as dtype,
  eq as eq,
  eq_p as eq_p,
  eq_to_p as eq_to_p,
  exp as exp,
  exp_p as exp_p,
  exp2 as exp2,
  exp2_p as exp2_p,
  expand_dims as expand_dims,
  expm1 as expm1,
  expm1_p as expm1_p,
  floor as floor,
  floor_p as floor_p,
  full as full,
  full_like as full_like,
  ge as ge,
  ge_p as ge_p,
  gt as gt,
  gt_p as gt_p,
  imag as imag,
  imag_p as imag_p,
  integer_pow as integer_pow,
  integer_pow_p as integer_pow_p,
  iota as iota,
  iota_p as iota_p,
  is_finite as is_finite,
  is_finite_p as is_finite_p,
  le as le,
  le_p as le_p,
  le_to_p as le_to_p,
  log as log,
  log1p as log1p,
  log1p_p as log1p_p,
  log_p as log_p,
  logistic as logistic,
  logistic_p as logistic_p,
  lt as lt,
  lt_p as lt_p,
  lt_to_p as lt_to_p,
  max as max,
  max_p as max_p,
  min as min,
  min_p as min_p,
  mul as mul,
  mul_p as mul_p,
  ne as ne,
  ne_p as ne_p,
  neg as neg,
  neg_p as neg_p,
  nextafter as nextafter,
  nextafter_p as nextafter_p,
  not_p as not_p,
  optimization_barrier as optimization_barrier,
  optimization_barrier_p as optimization_barrier_p,
  or_p as or_p,
  pad as pad,
  pad_p as pad_p,
  padtype_to_pads as padtype_to_pads,
  population_count as population_count,
  population_count_p as population_count_p,
  pow as pow,
  pow_p as pow_p,
  ragged_dot as ragged_dot,
  ragged_dot_general as ragged_dot_general,
  real as real,
  real_p as real_p,
  reciprocal as reciprocal,
  reduce as reduce,
  reduce_and as reduce_and,
  reduce_and_p as reduce_and_p,
  reduce_max as reduce_max,
  reduce_max_p as reduce_max_p,
  reduce_min as reduce_min,
  reduce_min_p as reduce_min_p,
  reduce_or as reduce_or,
  reduce_or_p as reduce_or_p,
  reduce_p as reduce_p,
  reduce_precision as reduce_precision,
  reduce_precision_p as reduce_precision_p,
  reduce_prod as reduce_prod,
  reduce_prod_p as reduce_prod_p,
  reduce_sum as reduce_sum,
  reduce_sum_p as reduce_sum_p,
  reduce_xor as reduce_xor,
  reduce_xor_p as reduce_xor_p,
  rem as rem,
  rem_p as rem_p,
  reshape as reshape,
  reshape_p as reshape_p,
  rev as rev,
  rev_p as rev_p,
  rng_bit_generator as rng_bit_generator,
  rng_bit_generator_p as rng_bit_generator_p,
  rng_uniform as rng_uniform,
  rng_uniform_p as rng_uniform_p,
  round as round,
  round_p as round_p,
  rsqrt as rsqrt,
  rsqrt_p as rsqrt_p,
  select as select,
  select_n as select_n,
  select_n_p as select_n_p,
  shape_as_value as shape_as_value,
  shift_left as shift_left,
  shift_left_p as shift_left_p,
  shift_right_arithmetic as shift_right_arithmetic,
  shift_right_arithmetic_p as shift_right_arithmetic_p,
  shift_right_logical as shift_right_logical,
  shift_right_logical_p as shift_right_logical_p,
  sign as sign,
  sign_p as sign_p,
  sin as sin,
  sin_p as sin_p,
  sinh as sinh,
  sinh_p as sinh_p,
  sort as sort,
  sort_key_val as sort_key_val,
  sort_p as sort_p,
  split as split,
  split_p as split_p,
  sqrt as sqrt,
  sqrt_p as sqrt_p,
  square as square,
  square_p as square_p,
  squeeze as squeeze,
  squeeze_p as squeeze_p,
  stop_gradient as stop_gradient,
  sub as sub,
  sub_p as sub_p,
  tan as tan,
  tan_p as tan_p,
  tanh as tanh,
  tanh_p as tanh_p,
  top_k as top_k,
  top_k_p as top_k_p,
  transpose as transpose,
  transpose_p as transpose_p,
  xor_p as xor_p,
  empty as empty,
  zeros_like_array as _deprecated_zeros_like_array,
)
from jax._src.lax.special import (
  bessel_i0e as bessel_i0e,
  bessel_i0e_p as bessel_i0e_p,
  bessel_i1e as bessel_i1e,
  bessel_i1e_p as bessel_i1e_p,
  betainc as betainc,
  digamma as digamma,
  digamma_p as digamma_p,
  erf as erf,
  erfc as erfc,
  erfc_p as erfc_p,
  erf_inv as erf_inv,
  erf_inv_p as erf_inv_p,
  erf_p as erf_p,
  igamma as igamma,
  igammac as igammac,
  igammac_p as igammac_p,
  igamma_grad_a as igamma_grad_a,
  igamma_grad_a_p as igamma_grad_a_p,
  igamma_p as igamma_p,
  lgamma as lgamma,
  lgamma_p as lgamma_p,
  polygamma as polygamma,
  polygamma_p as polygamma_p,
  random_gamma_grad as random_gamma_grad,
  regularized_incomplete_beta_p as regularized_incomplete_beta_p,
  zeta as zeta,
  zeta_p as zeta_p,
)
from jax._src.lax.slicing import (
  GatherDimensionNumbers as GatherDimensionNumbers,
  GatherScatterMode as GatherScatterMode,
  ScatterDimensionNumbers as ScatterDimensionNumbers,
  dynamic_index_in_dim as dynamic_index_in_dim,
  dynamic_slice as dynamic_slice,
  dynamic_slice_in_dim as dynamic_slice_in_dim,
  dynamic_slice_p as dynamic_slice_p,
  dynamic_update_index_in_dim as dynamic_update_index_in_dim,
  dynamic_update_slice as dynamic_update_slice,
  dynamic_update_slice_in_dim as dynamic_update_slice_in_dim,
  dynamic_update_slice_p as dynamic_update_slice_p,
  gather as gather,
  gather_p as gather_p,
  index_in_dim as index_in_dim,
  index_take as index_take,
  scatter as scatter,
  scatter_apply as scatter_apply,
  scatter_add as scatter_add,
  scatter_add_p as scatter_add_p,
  scatter_max as scatter_max,
  scatter_max_p as scatter_max_p,
  scatter_min as scatter_min,
  scatter_min_p as scatter_min_p,
  scatter_mul as scatter_mul,
  scatter_mul_p as scatter_mul_p,
  scatter_p as scatter_p,
  scatter_sub as scatter_sub,
  scatter_sub_p as scatter_sub_p,
  slice as slice,
  slice_in_dim as slice_in_dim,
  slice_p as slice_p,
)
from jax._src.lax.convolution import (
  ConvDimensionNumbers as ConvDimensionNumbers,
  ConvGeneralDilatedDimensionNumbers as ConvGeneralDilatedDimensionNumbers,
  conv as conv,
  conv_dimension_numbers as conv_dimension_numbers,
  conv_general_dilated as conv_general_dilated,
  conv_general_dilated_p as conv_general_dilated_p,
  conv_general_permutations as conv_general_permutations,
  conv_general_shape_tuple as conv_general_shape_tuple,
  conv_shape_tuple as conv_shape_tuple,
  conv_transpose as conv_transpose,
  conv_transpose_shape_tuple as conv_transpose_shape_tuple,
  conv_with_general_padding as conv_with_general_padding,
)
from jax._src.lax.windowed_reductions import (
  reduce_window as reduce_window,
  reduce_window_max_p as reduce_window_max_p,
  reduce_window_min_p as reduce_window_min_p,
  reduce_window_p as reduce_window_p,
  reduce_window_shape_tuple as reduce_window_shape_tuple,
  reduce_window_sum_p as reduce_window_sum_p,
  select_and_gather_add_p as select_and_gather_add_p,
  select_and_scatter_p as select_and_scatter_p,
  select_and_scatter_add_p as select_and_scatter_add_p,
)
from jax._src.lax.control_flow import (
  associative_scan as associative_scan,
  cond as cond,
  cond_p as cond_p,
  cumlogsumexp as cumlogsumexp,
  cumlogsumexp_p as cumlogsumexp_p,
  cummax as cummax,
  cummax_p as cummax_p,
  cummin as cummin,
  cummin_p as cummin_p,
  cumprod as cumprod,
  cumprod_p as cumprod_p,
  cumsum as cumsum,
  cumsum_p as cumsum_p,
  custom_linear_solve as custom_linear_solve,
  custom_root as custom_root,
  fori_loop as fori_loop,
  linear_solve_p as linear_solve_p,
  map as map,
  scan as scan,
  scan_p as scan_p,
  switch as switch,
  while_loop as while_loop,
  while_p as while_p,
  platform_dependent as platform_dependent,
)
from jax._src.lax.fft import (
  fft as fft,
  fft_p as fft_p,
  FftType as FftType,
)
from jax._src.lax.parallel import (
  all_gather as all_gather,
  pcast as pcast,
  all_gather_p as all_gather_p,
  all_to_all as all_to_all,
  all_to_all_p as all_to_all_p,
  axis_index as axis_index,
  axis_index_p as axis_index_p,
  axis_size as axis_size,
  pbroadcast as pbroadcast,
  pmax as pmax,
  pmax_p as pmax_p,
  pmean as pmean,
  pmin as pmin,
  pmin_p as pmin_p,
  ppermute as ppermute,
  ppermute_p as ppermute_p,
  psend as psend,
  precv as precv,
  pshuffle as pshuffle,
  psum as psum,
  psum_p as psum_p,
  psum_scatter as psum_scatter,
  pswapaxes as pswapaxes,
  ragged_all_to_all as ragged_all_to_all,
  ragged_all_to_all_p as ragged_all_to_all_p,
)
from jax._src.core import (
    pvary as _deprecated_pvary,
)
from jax._src.lax.other import (
  conv_general_dilated_local as conv_general_dilated_local,
  conv_general_dilated_patches as conv_general_dilated_patches
)
from jax._src.lax.ann import (
  approx_max_k as approx_max_k,
  approx_min_k as approx_min_k,
  approx_top_k_p as approx_top_k_p
)
from jax._src.ad_util import stop_gradient_p as stop_gradient_p
from jax.lax import linalg as linalg

from jax._src.pjit import with_sharding_constraint as with_sharding_constraint
from jax._src.pjit import sharding_constraint_p as sharding_constraint_p
from jax._src.dispatch import device_put_p as device_put_p

# Table of deprecated public names of this module. Each entry maps the
# deprecated attribute name to a `(message, object)` pair: the text shown to
# the user, and the object that the deprecated name still refers to. The table
# is handed to `deprecation_getattr` further down to build this module's
# `__getattr__`; the exact warning/raising behavior lives in
# `jax._src.deprecations`.
_deprecations = {
    # Added on July 24th 2025.
    "zeros_like_array": (
        (
            "jax.lax.zeros_like_array is deprecated. Use jax.numpy.zeros_like"
            " instead."
        ),
        _deprecated_zeros_like_array,
    ),
    # Added on Dec 1, 2025
    "pvary": (
        (
            "jax.lax.pvary is deprecated. Use"
            " `jax.lax.pcast(..., to='varying')` instead."
        ),
        _deprecated_pvary,
    ),
}

# Expose the deprecated names differently to static analysis and to runtime:
#   * Under a static type checker (`typing.TYPE_CHECKING` is true) they are
#     bound as ordinary module attributes, so references to
#     `jax.lax.zeros_like_array` / `jax.lax.pvary` still type-check.
#   * At runtime they are intentionally NOT bound in the module namespace.
#     Lookups of names missing from the namespace are routed to the
#     module-level `__getattr__` (PEP 562) built from `_deprecations`.
# Statement order matters: the `_deprecated_*` aliases and the helper must
# stay alive until they have been consumed, and are deleted only afterwards.
import typing as _typing
if _typing.TYPE_CHECKING:
  zeros_like_array = _deprecated_zeros_like_array
  pvary = _deprecated_pvary
else:
  from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
  # NOTE(review): presumably the returned hook warns with the message from
  # `_deprecations` and then returns the object -- confirm in
  # jax._src.deprecations.
  __getattr__ = _deprecation_getattr(__name__, _deprecations)
  del _deprecation_getattr
# Remove the private import aliases so they do not leak into the public
# `jax.lax` namespace; `_deprecations` still holds references to the objects.
del _deprecated_zeros_like_array, _deprecated_pvary
del _typing
